import numpy as np
import faiss
from transformers import BertTokenizer, BertModel
import torch
import json
from scipy.stats import mode
import warnings
from sklearn.metrics import precision_score
warnings.filterwarnings("ignore")

def most_frequent_element(arr):
    """Return the most frequent value found in each row of ``arr``.

    The mode is taken along axis 1 with ``scipy.stats.mode``; the resulting
    array of modes is flattened to one dimension before being returned.
    The per-row occurrence counts are not used.
    """
    row_modes = mode(arr, axis=1)[0]
    return row_modes.flatten()
def get_bert_embedding(instructions):
    """Embed every item of ``instructions`` with ``bert-base-uncased`` on CPU.

    Each item is tokenized separately (truncated to at most 512 tokens) and
    fed through the model one at a time. The embedding of an item is the mean
    of the model's ``last_hidden_state`` over dimension 1.

    Returns:
        A list holding one tensor per input item, in input order.
    """
    cpu = torch.device("cpu")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased").to(cpu)

    # Tokenize everything first; each item is encoded independently of the others.
    encoded_inputs = []
    for text in instructions:
        encoded = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
        encoded_inputs.append(encoded.to(cpu))

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        return [
            model(**encoded)['last_hidden_state'].mean(1)
            for encoded in encoded_inputs
        ]

def neiborhood_search(corpus_data, query_data, neiborhood_num=50):
    """Return, for every query vector, the indices of its nearest corpus vectors.

    Both inputs are sequences of tensors that are concatenated along
    dimension 0 and searched with an exact L2 faiss index
    (``faiss.IndexFlatL2``).

    Args:
        corpus_data: sequence of tensors forming the searchable set.
            Expected to concatenate into a 2-D (rows, dim) matrix.
        query_data: sequence of tensors forming the queries, with the same
            vector dimension as the corpus.
        neiborhood_num: number of neighbours requested per query
            (default 50, the value that used to be hard-coded).

    Returns:
        The index array produced by ``index.search``: one row per query,
        each entry a row position in the concatenated corpus. The number of
        columns is ``min(neiborhood_num, corpus_size)``.

    Raises:
        ValueError: if ``neiborhood_num`` is not a positive number.
    """
    if neiborhood_num <= 0:
        raise ValueError(
            f"neiborhood_num must be positive, got {neiborhood_num}"
        )

    # faiss works on C-contiguous float32 matrices.
    xq = np.ascontiguousarray(torch.cat(query_data, 0).cpu().numpy(), dtype=np.float32)
    xb = np.ascontiguousarray(torch.cat(corpus_data, 0).cpu().numpy(), dtype=np.float32)

    # Asking for more neighbours than there are corpus vectors makes faiss pad
    # the result with -1, which a caller using the result for fancy indexing
    # would silently read as "last element". Clamp instead.
    k = min(int(neiborhood_num), xb.shape[0])

    # Take the dimension from the data rather than assuming 768.
    index = faiss.IndexFlatL2(xb.shape[1])  # build the index
    index.add(xb)  # add vectors to the index
    _, neighbor_ids = index.search(xq, k)

    return neighbor_ids

if __name__ == "__main__":
    # Context managers make sure both JSON files are closed, even when
    # parsing fails.
    with open("data_1000.json", "r") as data_file:
        data_all = json.load(data_file)
    with open("label_1000.json", "r") as label_file:
        label_all = json.load(label_file)

    # The first `num_labeled` items act as the labelled reference set; every
    # remaining item is a query whose label is predicted from its neighbours.
    num_labeled = 100

    ## data and label
    data_100 = data_all[:num_labeled]
    label_100 = np.array(label_all[:num_labeled])

    data_900 = data_all[num_labeled:]
    label_900 = label_all[num_labeled:]

    ## bert embedding
    bert_embedding_all = get_bert_embedding(data_all)
    bert_embedding_100 = bert_embedding_all[:num_labeled]
    bert_embedding_900 = bert_embedding_all[num_labeled:]

    # Indices (into the reference set) of the nearest neighbours of each query.
    neiborhood = neiborhood_search(bert_embedding_100, bert_embedding_900)
    # Map neighbour indices to their labels, then take a majority vote per query.
    neiborhood_index = label_100[neiborhood]
    prediction = most_frequent_element(neiborhood_index)

    # Macro-averaged precision of the voted labels against the query labels.
    precision = precision_score(np.array(label_900), prediction, average='macro')

    print(f'Precision: {precision}')

